{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b2e6ab7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[101, 7592, 1010, 2023, 2003, 2034, 6251, 3975, 2046, 2616, 1012, 102], [101, 2023, 2003, 2117, 6251, 3975, 2046, 2616, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "#加载编码器\n",
    "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased',\n",
    "                                          use_fast=True)\n",
    "\n",
    "print(tokenizer)\n",
    "\n",
    "#编码试算\n",
    "tokenizer.batch_encode_plus([[\n",
    "    'Hello', ',', 'this', 'is', 'first', 'sentence', 'split', 'into', 'words',\n",
    "    '.'\n",
    "], ['This', 'is', 'second', 'sentence', 'split', 'into', 'words', '.']],\n",
    "                            is_split_into_words=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5e6c8ffc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ner'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义任务类型,这里可以取pos,chunk,ner\n",
    "task = 'ner'\n",
    "\n",
    "task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc4355f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "查看数据样例\n",
      "{'id': '0', 'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0], 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}\n",
      "获取label的名字,这个后面计算评价指标的时候要用\n",
      "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 14042\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 3251\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 3454\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, load_from_disk\n",
    "\n",
    "#加载数据\n",
    "#dataset = load_dataset(path='conll2003')\n",
    "dataset = load_from_disk('datas/conll2003')\n",
    "\n",
    "print('查看数据样例')\n",
    "print(dataset['train'][0])\n",
    "\n",
    "print('获取label的名字,这个后面计算评价指标的时候要用')\n",
    "label_name = dataset['train'].features['%s_tags' % task].feature.names\n",
    "print(label_name)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a1426890",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "就算设置了is_split_into_words=True,编码后的数字和原文里的单词也不一定是一一对应的\n",
      "以下面这个数据为例\n",
      "{'id': '2', 'tokens': ['BRUSSELS', '1996-08-22'], 'pos_tags': [22, 11], 'chunk_tags': [11, 12], 'ner_tags': [5, 0]}\n",
      "调用编码\n",
      "可以看到原文里只有2个单词,但是编码后是7个词\n",
      "{'input_ids': [[101, 9371, 2727, 1011, 5511, 1011, 2570, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1]]}\n",
      "恢复成文本之后长这样\n",
      "[CLS] brussels 1996 - 08 - 22 [SEP]\n",
      "因为有这样的情况存在,所以编码后的数字和label不能做到一一对应\n",
      "但是可以在编码结果上使用word_ids函数得到每个编码对应的原文索引\n",
      "特殊标识的索引是None,很显然,因为特殊符号不来自于原文的任何位置\n",
      "[None, 0, 1, 1, 1, 1, 1, None]\n",
      "有了word_ids,就可以找到编码后的数字对应的label,可以做到一一对应\n",
      "特殊符号的label置为-100,这和模型的预训练情况有关\n",
      "[-100, 5, 0, 0, 0, 0, 0, -100]\n"
     ]
    }
   ],
   "source": [
    "#以下为原理说明性代码\n",
    "print('就算设置了is_split_into_words=True,编码后的数字和原文里的单词也不一定是一一对应的')\n",
    "print('以下面这个数据为例')\n",
    "data = dataset['train'][2]\n",
    "print(data)\n",
    "\n",
    "print('调用编码')\n",
    "data_encode = tokenizer.batch_encode_plus([data['tokens']],\n",
    "                                          is_split_into_words=True)\n",
    "\n",
    "print('可以看到原文里只有2个单词,但是编码后是7个词')\n",
    "print(data_encode)\n",
    "\n",
    "print('恢复成文本之后长这样')\n",
    "print(tokenizer.decode(data_encode['input_ids'][0]))\n",
    "\n",
    "print('因为有这样的情况存在,所以编码后的数字和label不能做到一一对应')\n",
    "print('但是可以在编码结果上使用word_ids函数得到每个编码对应的原文索引')\n",
    "print('特殊标识的索引是None,很显然,因为特殊符号不来自于原文的任何位置')\n",
    "print(data_encode.word_ids(batch_index=0))\n",
    "\n",
    "print('有了word_ids,就可以找到编码后的数字对应的label,可以做到一一对应')\n",
    "print('特殊符号的label置为-100,这和模型的预训练情况有关')\n",
    "label = [\n",
    "    -100 if i == None else data['ner_tags'][i]\n",
    "    for i in data_encode.word_ids(batch_index=0)\n",
    "]\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f7ba186f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/conll2003/train/cache-87e556de43d77835.arrow\n",
      "Loading cached processed dataset at datas/conll2003/validation/cache-c304fc744fc5c076.arrow\n",
      "Loading cached processed dataset at datas/conll2003/test/cache-46393ddde082627e.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 14042\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 3251\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 3454\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#根据以上说明性代码,写出这个数据处理函数\n",
    "def tokenize_and_align_labels(data):\n",
    "    #分词\n",
    "    data_encode = tokenizer.batch_encode_plus(data['tokens'],\n",
    "                                              truncation=True,\n",
    "                                              is_split_into_words=True)\n",
    "\n",
    "    data_encode['labels'] = []\n",
    "    for i in range(len(data['tokens'])):\n",
    "        label = []\n",
    "        for word_id in data_encode.word_ids(batch_index=i):\n",
    "            if word_id is None:\n",
    "                label.append(-100)\n",
    "            else:\n",
    "                label.append(data['%s_tags' % task][i][word_id])\n",
    "\n",
    "        data_encode['labels'].append(label)\n",
    "\n",
    "    return data_encode\n",
    "\n",
    "\n",
    "dataset = dataset.map(\n",
    "    tokenize_and_align_labels,\n",
    "    batched=True,\n",
    "    batch_size=1000,\n",
    "    num_proc=1,\n",
    "    remove_columns=['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'])\n",
    "\n",
    "print(dataset['train'][0])\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "26000c13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([8, 27]) tensor([[ 101, 5548, 1017, 2624, 3799, 1014,  102,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0],\n",
      "        [ 101, 1000, 1996, 2142, 2983, 3504, 2830, 2000, 1996, 2220, 8085, 1997,\n",
      "         1996, 5036, 1010, 1000, 1996, 4861, 2056, 1012,  102,    0,    0,    0,\n",
      "            0,    0,    0]])\n",
      "attention_mask torch.Size([8, 27]) tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
      "         0, 0, 0]])\n",
      "labels torch.Size([8, 27]) tensor([[-100,    3,    0,    3,    4,    0, -100, -100, -100, -100, -100, -100,\n",
      "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
      "         -100, -100, -100],\n",
      "        [-100,    0,    0,    5,    6,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0, -100, -100, -100, -100,\n",
      "         -100, -100, -100]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1755"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    dataset=dataset['train'],\n",
    "    batch_size=8,\n",
    "    collate_fn=DataCollatorForTokenClassification(tokenizer),\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    ")\n",
    "\n",
    "for i, data in enumerate(loader):\n",
    "    break\n",
    "\n",
    "for k, v in data.items():\n",
    "    print(k, v.shape, v[:2])\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "72a6db43",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight']\n",
      "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight']\n",
      "- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6636.9801\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(2.3528, grad_fn=<NllLossBackward0>), torch.Size([8, 27, 9]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForTokenClassification, DistilBertModel\n",
    "\n",
    "#加载模型\n",
    "#model = AutoModelForTokenClassification.from_pretrained('distilbert-base-uncased', num_labels=len(label_name))\n",
    "\n",
    "\n",
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self, num_labels):\n",
    "        super().__init__()\n",
    "        self.pretrained = DistilBertModel.from_pretrained(\n",
    "            'distilbert-base-uncased')\n",
    "\n",
    "        self.fc = torch.nn.Sequential(torch.nn.Dropout(0.1),\n",
    "                                      torch.nn.Linear(768, num_labels))\n",
    "\n",
    "        #加载预训练模型的参数\n",
    "        parameters = AutoModelForTokenClassification.from_pretrained(\n",
    "            'distilbert-base-uncased', num_labels=len(label_name))\n",
    "        self.fc[1].load_state_dict(parameters.classifier.state_dict())\n",
    "\n",
    "        self.criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        logits = self.pretrained(input_ids=input_ids,\n",
    "                                 attention_mask=attention_mask)\n",
    "        logits = logits.last_hidden_state\n",
    "\n",
    "        logits = self.fc(logits)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())\n",
    "\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "\n",
    "\n",
    "model = Model(num_labels=len(label_name))\n",
    "\n",
    "#统计参数量\n",
    "print(sum(i.numel() for i in model.parameters()) / 10000)\n",
    "\n",
    "out = model(**data)\n",
    "\n",
    "out['loss'], out['logits'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abca355c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_metric\n",
    "\n",
    "#加载评价函数\n",
    "#metric = load_metric('seqeval')\n",
    "\n",
    "#要计算必须把label转换成名字形态,有点蛋疼\n",
    "#metric.compute(predictions=[label_name], references=[label_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c924a315",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "10\n",
      "20\n",
      "30\n",
      "40\n",
      "50\n",
      "0.027704584807648876\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "\n",
    "    #数据加载器\n",
    "    loader_test = torch.utils.data.DataLoader(\n",
    "        dataset=dataset['test'],\n",
    "        batch_size=16,\n",
    "        collate_fn=DataCollatorForTokenClassification(tokenizer),\n",
    "        shuffle=True,\n",
    "        drop_last=True,\n",
    "    )\n",
    "\n",
    "    labels = []\n",
    "    outs = []\n",
    "    for i, data in enumerate(loader_test):\n",
    "        #计算\n",
    "        with torch.no_grad():\n",
    "            out = model(**data)\n",
    "\n",
    "        out = out['logits'].argmax(dim=2)\n",
    "\n",
    "        for j in range(16):\n",
    "            #使用attention_mask筛选label,很显然,不需要pad的预测结果\n",
    "            #另外首尾两个特殊符号也不需要预测结果\n",
    "            select = data['attention_mask'][j] == 1\n",
    "            labels.append(data['labels'][j][select][1:-1])\n",
    "            outs.append(out[j][select][1:-1])\n",
    "\n",
    "        if i % 10 == 0:\n",
    "            print(i)\n",
    "\n",
    "        if i == 50:\n",
    "            break\n",
    "\n",
    "    #计算评价指标,要计算必须把label转换成名字形态,有点蛋疼\n",
    "    \"\"\"labels_name = [[label_name[j] for j in i] for i in labels]\n",
    "    outs_name = [[label_name[j] for j in i] for i in outs]\n",
    "\n",
    "    metric_out = metric.compute(predictions=[outs_name],\n",
    "                                references=[labels_name])\n",
    "\n",
    "    for k in [\n",
    "            'overall_precision', 'overall_recall', 'overall_f1',\n",
    "            'overall_accuracy'\n",
    "    ]:\n",
    "        print(k, metric_out[k])\"\"\"\n",
    "\n",
    "    #计算正确率\n",
    "    labels = torch.cat(labels)\n",
    "    outs = torch.cat(outs)\n",
    "\n",
    "    print((labels == outs).sum().item() / len(labels))\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1375fcad",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.3589563369750977 0.0379746835443038 1.998860398860399e-05\n",
      "50 0.9524945020675659 0.652542372881356 1.941880341880342e-05\n",
      "100 0.3112027943134308 0.9279279279279279 1.8849002849002852e-05\n",
      "150 0.30391865968704224 0.925 1.827920227920228e-05\n",
      "200 0.106633260846138 0.9844961240310077 1.770940170940171e-05\n",
      "250 0.19684122502803802 0.9523809523809523 1.713960113960114e-05\n",
      "300 0.18589724600315094 0.9150943396226415 1.6569800569800573e-05\n",
      "350 0.20196762681007385 0.912621359223301 1.6000000000000003e-05\n",
      "400 0.09656360000371933 0.9636363636363636 1.5430199430199432e-05\n",
      "450 0.04225413128733635 0.9852941176470589 1.4860398860398862e-05\n",
      "500 0.06595474481582642 0.9791666666666666 1.4290598290598293e-05\n",
      "550 0.1935502141714096 0.959349593495935 1.3720797720797722e-05\n",
      "600 0.062100253999233246 0.98 1.3150997150997152e-05\n",
      "650 0.12392968684434891 0.9672131147540983 1.2581196581196581e-05\n",
      "700 0.04459675773978233 0.9878048780487805 1.2011396011396012e-05\n",
      "750 0.1767694503068924 0.9523809523809523 1.1441595441595444e-05\n",
      "800 0.013531548902392387 1.0 1.0871794871794871e-05\n",
      "850 0.040851347148418427 0.9910714285714286 1.0301994301994303e-05\n",
      "900 0.009358435869216919 1.0 9.732193732193734e-06\n",
      "950 0.06090879067778587 0.9811320754716981 9.162393162393163e-06\n",
      "1000 0.22171254456043243 0.94 8.592592592592593e-06\n",
      "1050 0.04289395362138748 0.9819277108433735 8.022792022792024e-06\n",
      "1100 0.12170761078596115 0.9642857142857143 7.452991452991454e-06\n",
      "1150 0.5125896334648132 0.883495145631068 6.883190883190884e-06\n",
      "1200 0.06157579645514488 0.9938650306748467 6.313390313390314e-06\n",
      "1250 0.26822757720947266 0.937007874015748 5.743589743589743e-06\n",
      "1300 0.17063970863819122 0.94 5.173789173789175e-06\n",
      "1350 0.11368001997470856 0.9578947368421052 4.603988603988604e-06\n",
      "1400 0.09050043672323227 0.9782608695652174 4.034188034188034e-06\n",
      "1450 0.05749855935573578 0.983739837398374 3.4643874643874647e-06\n",
      "1500 0.23470167815685272 0.9470198675496688 2.894586894586895e-06\n",
      "1550 0.08962950855493546 0.9779411764705882 2.324786324786325e-06\n",
      "1600 0.29448533058166504 0.9436619718309859 1.7549857549857553e-06\n",
      "1650 0.002838065382093191 1.0 1.1851851851851854e-06\n",
      "1700 0.29472947120666504 0.9259259259259259 6.153846153846155e-07\n",
      "1750 0.04722971096634865 0.9782608695652174 4.5584045584045586e-08\n"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "#训练\n",
    "def train():\n",
    "    optimizer = AdamW(model.parameters(), lr=2e-5)\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader),\n",
    "                              optimizer=optimizer)\n",
    "\n",
    "    model.train()\n",
    "    for i, data in enumerate(loader):\n",
    "        out = model(**data)\n",
    "        loss = out['loss']\n",
    "\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        model.zero_grad()\n",
    "\n",
    "        if i % 50 == 0:\n",
    "            labels = []\n",
    "            outs = []\n",
    "            out = out['logits'].argmax(dim=2)\n",
    "            for j in range(8):\n",
    "                #使用attention_mask筛选label,很显然,不需要pad的预测结果\n",
    "                #另外首尾两个特殊符号也不需要预测结果\n",
    "                select = data['attention_mask'][j] == 1\n",
    "                labels.append(data['labels'][j][select][1:-1])\n",
    "                outs.append(out[j][select][1:-1])\n",
    "\n",
    "            #不知道为什么要会打印这么多warning,干脆不算了\n",
    "            \"\"\"#计算评价指标,要计算必须把label转换成名字形态,有点蛋疼\n",
    "            labels_name = [[label_name[j] for j in i] for i in labels]\n",
    "            outs_name = [[label_name[j] for j in i] for i in outs]\n",
    "            \n",
    "            metric_out = metric.compute(predictions=[outs_name],\n",
    "                                        references=[labels_name])\n",
    "\n",
    "            for k in [\n",
    "                    'overall_precision', 'overall_recall', 'overall_f1',\n",
    "                    'overall_accuracy'\n",
    "            ]:\n",
    "                print(k, metric_out[k])\"\"\"\n",
    "\n",
    "            #计算正确率\n",
    "            labels = torch.cat(labels)\n",
    "            outs = torch.cat(outs)\n",
    "            accuracy = (labels == outs).sum().item() / len(labels)\n",
    "\n",
    "            lr = optimizer.state_dict()['param_groups'][0]['lr']\n",
    "\n",
    "            print(i, loss.item(), accuracy, lr)\n",
    "\n",
    "    torch.save(model, 'models/6.命名实体识别_%s.model' % task)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ed397281",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "10\n",
      "20\n",
      "30\n",
      "40\n",
      "50\n",
      "0.9709837743020027\n"
     ]
    }
   ],
   "source": [
    "model = torch.load('models/6.命名实体识别_%s.model' % task)\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
